// SPDX-License-Identifier: GPL-2.0
/*
 * BiscuitOS SPECIAL DEVICE PAGE on GUP
 *
 * (C) 2024.04.10 BuddyZhang1 <buddy.zhang@aliyun.com>
 * (C) 2024.04.10 BiscuitOS
 */
#include <linux/init.h>
#include <linux/kernel.h>
#include <linux/module.h>
#include <linux/miscdevice.h>
#include <linux/fs.h>
#include <linux/slab.h>
#include <linux/mm.h>
#include <linux/err.h>
#include <linux/mman.h>

#define DEV_NAME		"BiscuitOS"
/*
 * Local 2MiB alignment unit for BiscuitOS_get_unmapped_area(). HPAGE_SIZE
 * may already be provided by <asm/page.h> on some architectures, so drop
 * any prior definition instead of triggering a redefinition warning.
 */
#undef HPAGE_SIZE
#define HPAGE_SIZE		(2UL * 1024 * 1024) /* 2MiB */
/* PFN the write() self-test assumed for the first MMIO page */
#define MMIO_BASE		0x80000

/*
 * Bookkeeping for the fake struct page pool built by BiscuitOS_mmap() and
 * consulted by BiscuitOS_find_special_page().
 */
struct BiscuitOS_page_pool {
	struct page *mem_map;	/* first entry of the struct page array */
	unsigned long vm_start;	/* user address that corresponds to mem_map[0] */
};

static struct BiscuitOS_page_pool bpool;

/* MMIO */
static struct resource BiscuitOS_mmio_res = {
	.name	= "BiscuitOS_P2P_MMIO",
	.flags	= IORESOURCE_IO,
};

/** ALLOC ZONE_DEVICE SPECIAL PAGE **/

static struct page *BiscuitOS_find_special_page(struct vm_area_struct *vma,
                                          unsigned long addr)
{
	struct page *page;

	if (addr < bpool.vm_start)
		return NULL;

	page = bpool.mem_map + ((addr - bpool.vm_start) / PAGE_SIZE);

	return page;
}

static struct vm_operations_struct BiscuitOS_vm_ops = {
	.find_special_page = BiscuitOS_find_special_page,
}; 

static int BiscuitOS_mmap(struct file *filp, struct vm_area_struct *vma)
{
	unsigned long mmio_base = vma->vm_pgoff << PAGE_SHIFT;
	resource_size_t size = vma->vm_end - vma->vm_start;
	unsigned long page_size, addr, pfn;
	struct page *page;

	/* BUILD SPECIAL PAGE POOL */
	page_size = (size / PAGE_SIZE) * sizeof(struct page) / PAGE_SIZE;
	page = alloc_pages(GFP_KERNEL, ilog2(page_size));
	if (!page) {
		printk("ERROR: PAGE ALLOC\n");
		return -ENOMEM;
	}
	bpool.mem_map = page_to_virt(page);
	bpool.vm_start = vma->vm_start;

	/* REGISTER MMIO REGION */
	BiscuitOS_mmio_res.start = mmio_base;
	BiscuitOS_mmio_res.end   = mmio_base + size;
	if (request_resource(&iomem_resource, &BiscuitOS_mmio_res) < 0) {
		printk("ERROR: MMIO BAD\n");
		return -EINVAL;
	}

	for (addr = vma->vm_start, pfn = BiscuitOS_mmio_res.start >> PAGE_SHIFT;
		addr < vma->vm_end; addr += PAGE_SIZE, pfn++) {
		/* SETUP SPECIAL PAGE */
		page = bpool.mem_map + ((addr - bpool.vm_start) / PAGE_SIZE);
		atomic_set(&page->_refcount, 2); /* ONLY ONE REFERENCE */
		atomic_set(&page->_mapcount, 0); /* ONLY ONE MAPPED */
		remap_pfn_range(vma, addr, pfn, PAGE_SIZE, vma->vm_page_prot);
	}

	/* SPECIAL VMA & PAGE */
	vma->vm_flags &= ~(VM_IO | VM_PFNMAP);
	vma->vm_ops = &BiscuitOS_vm_ops;

	return 0;
}

static unsigned long BiscuitOS_get_unmapped_area(struct file *filp,
                unsigned long uaddr, unsigned long len,
                unsigned long pgoff, unsigned long flags)
{
	unsigned long align_addr;

	align_addr = current->mm->get_unmapped_area(NULL, 0,
					len + HPAGE_SIZE, 0, flags);
	/* Aligned on 2MiB */
	align_addr = round_up(align_addr, HPAGE_SIZE);

	return align_addr;
}

/** INTERFACE TEST **/

/*
 * BiscuitOS_write - GUP self-test over the user address passed as @buf.
 * @filp:   device file (unused)
 * @buf:    user address where the scan starts (its contents are not read)
 * @len:    returned unchanged
 * @offset: unused
 *
 * Walks 512MiB starting at @buf, one page at a time, with get_user_pages().
 * For every page that resolves inside a BiscuitOS mapping it prints the PFN
 * it was remapped to. It also prints "FIND SPECIAL PAGE." for the pool's
 * first entry. Each reference that GUP takes is dropped again.
 *
 * Return: @len.
 */
static ssize_t BiscuitOS_write(struct file *filp, const char __user *buf,
			size_t len, loff_t *offset)
{
	unsigned long addr = (unsigned long)buf;
	unsigned long scan_len = 512 * 1024 * 1024;
	unsigned long scan_end = addr + scan_len;
	/* PFN that BiscuitOS_mmap() mapped to bpool.mem_map[0] */
	unsigned long base_pfn = BiscuitOS_mmio_res.start >> PAGE_SHIFT;
	unsigned long start;

	/* get_user_pages() must be called with mmap_lock held. */
	mmap_read_lock(current->mm);
	for (start = addr; start < scan_end; start += PAGE_SIZE) {
		struct vm_area_struct *vma;
		struct page *page;

		/* Unmapped/faulting address: no page, nothing to report. */
		if (get_user_pages(start, 1, 0, &page, &vma) != 1)
			continue;

		/* Only pages from our pool may be indexed against mem_map. */
		if (vma->vm_ops == &BiscuitOS_vm_ops) {
			if (page == bpool.mem_map)
				printk("FIND SPECIAL PAGE.\n");
			printk("PAGE PFN: %#lx\n",
			       base_pfn + (unsigned long)(page - bpool.mem_map));
		}

		/* Drop the reference GUP took. */
		put_page(page);
	}
	mmap_read_unlock(current->mm);

	return len;
}

/** MISC DEVICE **/

static struct file_operations BiscuitOS_fops = {
	.owner		= THIS_MODULE,
	.write		= BiscuitOS_write,
	.mmap		= BiscuitOS_mmap,
	.get_unmapped_area = BiscuitOS_get_unmapped_area,
};

/* Misc device named DEV_NAME, minor number assigned dynamically at register time. */
static struct miscdevice BiscuitOS_drv = {
	.name	= DEV_NAME,
	.fops	= &BiscuitOS_fops,
	.minor	= MISC_DYNAMIC_MINOR,
};

/* Module entry point: registers the misc device and returns misc_register()'s result. */
static int __init BiscuitOS_init(void)
{
	int err = misc_register(&BiscuitOS_drv);

	return err;
}

/*
 * Module exit point: unregisters the misc device, then releases the MMIO
 * resource if an mmap claimed it.
 */
static void __exit BiscuitOS_exit(void)
{
	/* Stop new opens/mmaps before tearing down the shared resource. */
	misc_deregister(&BiscuitOS_drv);

	/*
	 * The resource is only inserted by a successful mmap. If none happened,
	 * its parent is NULL and remove_resource() would walk a NULL parent.
	 */
	if (BiscuitOS_mmio_res.parent)
		remove_resource(&BiscuitOS_mmio_res);

	/* NOTE(review): the struct page pool allocated in mmap is never freed. */
}

module_init(BiscuitOS_init);
module_exit(BiscuitOS_exit);

MODULE_LICENSE("GPL");
MODULE_AUTHOR("BiscuitOS <buddy.zhang@aliyun.com>");
MODULE_DESCRIPTION("BiscuitOS SPECIAL PAGE");
